make realnvp and nsf layers as part of the pkg #53
Conversation
After the recent Mooncake update, with the following setup:

```julia
using Random, Distributions, LinearAlgebra
using Flux, Optimisers
using Functors: @functor
using Bijectors
using Bijectors: partition, combine, PartitionMask
using Mooncake, Enzyme, ADTypes
import DifferentiationInterface as DI

# just define an MLP
function mlp3(
    input_dim::Int,
    hidden_dims::Int,
    output_dim::Int;
    activation=Flux.leakyrelu,
    paramtype::Type{T} = Float64,
) where {T<:AbstractFloat}
    m = Chain(
        Flux.Dense(input_dim, hidden_dims, activation),
        Flux.Dense(hidden_dims, hidden_dims, activation),
        Flux.Dense(hidden_dims, output_dim),
    )
    return Flux._paramtype(paramtype, m)
end

inputdim = 4
mask_idx = 1:2:inputdim

# create a masking layer
mask = PartitionMask(inputdim, mask_idx)
cdim = length(mask_idx)

x = randn(inputdim)
t_net = mlp3(cdim, 16, cdim; paramtype = Float64)
ps, st = Optimisers.destructure(t_net)
```

the following code runs perfectly:

```julia
function loss(ps, st, x, mask)
    t_net = st(ps)
    x₁, x₂, x₃ = partition(mask, x)
    y₁ = x₁ .+ t_net(x₂)
    y = combine(mask, y₁, x₂, x₃)
    # println("y = ", y)
    return sum(abs2, y)
end

loss(ps, st, x, mask) # returns 3.0167880799441793
val, grad = DI.value_and_gradient(
    loss,
    ADTypes.AutoMooncake(; config = Mooncake.Config()),
    ps, DI.Cache(st), DI.Constant(x), DI.Constant(mask)
)
```

but autograd fails if I wrap the mask and the network into a struct:

```julia
struct ACL
    mask::Bijectors.PartitionMask
    t::Flux.Chain
end
@functor ACL (t, )

acl = ACL(mask, t_net)
psacl, stacl = Optimisers.destructure(acl)
function loss_acl(ps, st, x)
    acl = st(ps)
    t_net = acl.t
    mask = acl.mask
    x₁, x₂, x₃ = partition(mask, x)
    y₁ = x₁ .+ t_net(x₂)
    y = combine(mask, y₁, x₂, x₃)
    return sum(abs2, y)
end

loss_acl(psacl, stacl, x) # returns 3.0167880799441793
val, grad = DI.value_and_gradient(
    loss_acl,
    ADTypes.AutoMooncake(; config = Mooncake.Config()),
    psacl, DI.Cache(stacl), DI.Constant(x)
)
```

with an error message. The Enzyme backend errors as well:

```julia
val, grad = DI.value_and_gradient(
    loss_acl,
    ADTypes.AutoEnzyme(;
        mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
        function_annotation=Enzyme.Const,
    ),
    psacl, DI.Cache(stacl), DI.Constant(x)
)
```
Any thoughts on this, @yebai @willtebbutt?
Ah, looks like it only has an issue when just part of the fields in the struct are annotated by `@functor`. With every field annotated, the following works:

```julia
struct Holder
    t::Flux.Chain
end
@functor Holder

psh, sth = Optimisers.destructure(Holder(t_net))
function loss2(ps, st, x)
    holder = st(ps)
    t_net = holder.t
    y = x .+ t_net(x)
    return sum(abs2, y)
end

loss2(psh, sth, x) # returns 7.408352005690478
val, grad = DI.value_and_gradient(
    loss2,
    ADTypes.AutoMooncake(; config = Mooncake.Config()),
    psh, DI.Cache(sth), DI.Constant(x)
)
```
@zuhengxu, can you help bisect which Mooncake version / Julia version introduced this bug?
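A minimal sketch of one way to run that bisection: pin a few candidate Mooncake releases in throwaway environments and re-run the failing snippet in each. The version numbers and the `mwe.jl` script name below are placeholders, not the actual range or file in this repo.

```julia
# Hypothetical bisection loop (sketch): the version list is a placeholder, and
# `mwe.jl` is assumed to contain the failing loss_acl gradient snippet from above.
using Pkg

for v in ["0.4.50", "0.4.55", "0.4.60"]   # placeholder Mooncake versions
    mktempdir() do dir
        Pkg.activate(dir)
        Pkg.add([
            PackageSpec(name = "Mooncake", version = v),
            PackageSpec("DifferentiationInterface"), PackageSpec("ADTypes"),
            PackageSpec("Bijectors"), PackageSpec("Flux"),
            PackageSpec("Functors"), PackageSpec("Optimisers"),
        ])
        # run the MWE in a fresh process so each Mooncake version loads cleanly
        ok = success(`$(Base.julia_cmd()) --project=$dir mwe.jl`)
        @info "Mooncake $v" gradient_ok = ok
    end
end
```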
Good point! I'll look at this today.
It appears that the remaining issues with Mooncake are minor, likely due to a lack of a specific rule. @sunxd3, can you help if it requires a new rule?
I'll look into it 👍
Hi! I have some minor suggestions.
NormalizingFlows.jl documentation for PR #53 is available at:
Thank you @yebai @sunxd3 @Red-Portal again for the help and comments throughout this PR! Let me know if it looks good to you and I'll merge it afterwards.
Sorry for the delay. Reviewing a paper for JMLR has been taking up all my bandwidth. I'll take a deeper look tomorrow.
I only have minor suggestions. Feel free to take a look and apply them if you agree. Otherwise, looks good to me.
Sorry for missing the tag; let me take a look later today or tomorrow.
A couple of tiny things; very happy to do another round of review.
Thank you @sunxd3 @Red-Portal for the review! I made the corresponding updates; let me know if the current version looks good to you!
Another couple of tiny things; nothing major beyond these.
Pretty much good to go from my end, but let's wait for Kyurae to take a look.
I only have a few minor comments that should be quick to handle!
@Red-Portal @sunxd3 Let me know if I can hit the big green button! Thanks for the quick feedback.
Alright, looks good to me now!
As discussed in #36 (see #36 (comment)), I'm moving the `AffineCoupling` and `NeuralSplineLayer` from the example to `src/` so they can be called directly:

- Move `AffineCoupling` and `NeuralSplineLayer` into `src`.
- Add a `realnvp` and a `neuralsplineflow` constructor. For `realnvp`, follow the default architecture as mentioned in Advances in Black-Box VI: Normalizing Flows, Importance Weighting, and Optimization.
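For reference, here is a minimal sketch of what such a coupling layer looks like, based on the snippet in the discussion above. The field names, the scale/shift split, and the constructor are assumptions; the actual `AffineCoupling` added to `src/` in this PR may differ.

```julia
using Flux
using Functors: @functor
using Bijectors
using Bijectors: partition, combine, PartitionMask

# Sketch only: an affine coupling layer in the style of the MWE above.
struct AffineCoupling <: Bijectors.Bijector
    mask::Bijectors.PartitionMask
    s::Flux.Chain   # scale network, conditioned on the untouched partition
    t::Flux.Chain   # shift network, conditioned on the untouched partition
end

# only the networks carry trainable parameters; the mask is static
@functor AffineCoupling (s, t)

function Bijectors.transform(af::AffineCoupling, x::AbstractVector)
    x₁, x₂, x₃ = partition(af.mask, x)    # split x according to the mask
    y₁ = x₁ .* af.s(x₂) .+ af.t(x₂)       # affine transform of the masked part
    return combine(af.mask, y₁, x₂, x₃)   # reassemble into the original layout
end
```

A `realnvp` constructor would then presumably stack several such layers with alternating masks on top of a base distribution.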